/*
 * IBM Corporation.
 * Copyright (c) 2014 All Rights Reserved.
 */

package com.ibm.iisp.common.security.ws;

import java.rmi.RemoteException;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Level;
import java.util.logging.Logger;

import javax.naming.Context;
import javax.naming.InitialContext;
import javax.naming.NamingException;
import javax.sql.DataSource;

import com.ibm.websphere.security.PasswordCheckFailedException;

/**
 * Verifies user passwords against the {@code PASSWD} column of the
 * {@code IISP_AUTH_USER} table, looked up by {@code USER_CODE} through a
 * {@link DataSource} resolved from JNDI.
 * <p>
 * Usage: construct with the JNDI name of the data source and call
 * {@link #checkPassword(String, String)}. A non-empty supplied password is
 * hashed with {@code SecurityUtils.MD5} before it is compared with the stored
 * value; a {@code null} or empty password is compared as the empty string.
 * <p>
 * The stored value of every user found in the database is cached for 10
 * minutes. A password changed in the database is therefore only picked up
 * once the cached entry for that user expires.
 * <p>
 * The cache is a {@link ConcurrentHashMap}, so one instance may be shared by
 * concurrent callers.
 *
 * @author Johnny@cn.ibm.com
 */
public class DBPasswordChecker {
	Logger log = Logger.getLogger(this.getClass().getName());

	/**
	 * A cached stored-password value together with the time it was loaded.
	 * {@link DBPasswordChecker#checkPassword(String, String)} compares
	 * {@link #getTime()} with {@link System#currentTimeMillis()}.
	 */
	public static class CacheItem {
		private long time;
		private String value;

		/**
		 * @param time  load time in epoch milliseconds
		 * @param value the stored password value
		 */
		public CacheItem(long time, String value) {
			super();
			this.time = time;
			this.value = value;
		}

		public long getTime() {
			return time;
		}

		public void setTime(long time) {
			this.time = time;
		}

		public String getValue() {
			return value;
		}

		public void setValue(String value) {
			this.value = value;
		}

	}

	private final String dataSourceName;
	/** Lazily resolved from {@link #dataSourceName}; volatile so the lookup result is safely published. */
	private volatile DataSource ds;
	/**
	 * Cache lifetime of 600 seconds. It mainly avoids hitting the database on
	 * every login during stress tests.
	 */
	private static final long cacheLiveTime = 600 * 1000L;
	final String sql = "select PASSWD from IISP_AUTH_USER where USER_CODE = ? ";

	/**
	 * @param dataSourceName JNDI name of the data source holding {@code IISP_AUTH_USER}
	 */
	public DBPasswordChecker(String dataSourceName) {
		super();
		this.dataSourceName = dataSourceName;
	}

	/**
	 * Per-user cache of stored password values. It is a concurrent map in case
	 * {@link #checkPassword(String, String)} is invoked from several threads;
	 * a plain HashMap can be corrupted by concurrent puts.
	 */
	private final ConcurrentHashMap<String, CacheItem> passCache = new ConcurrentHashMap<>(100);

	/**
	 * Checks {@code password} for {@code username}.
	 * <p>
	 * If a cache entry younger than the cache lifetime exists, it alone decides
	 * the result, and a mismatch fails without consulting the database.
	 * Otherwise the stored value is re-read from the database and re-cached.
	 *
	 * @param username the user code to look up; must not be {@code null}
	 * @param password the supplied (unhashed) password; {@code null} is treated as empty
	 * @return {@code username} when the password matches
	 * @throws PasswordCheckFailedException if the username is {@code null}, the
	 *                                      user does not exist, or the password
	 *                                      does not match
	 * @throws RemoteException              if the data source cannot be resolved
	 *                                      or queried
	 */
	public String checkPassword(String username, String password)
		throws PasswordCheckFailedException, RemoteException {
		if (username == null) {
			// ConcurrentHashMap rejects null keys, and a null USER_CODE can never match a row.
			throw new PasswordCheckFailedException("null not found.");
		}
		// Empty passwords are compared as-is; anything else is hashed to match the stored form.
		// NOTE(review): an unsalted fast digest (presumably MD5, per the helper name) is weak for
		// password storage; replacing it requires migrating the stored PASSWD values.
		String hashed = (password == null || password.isEmpty()) ? "" : SecurityUtils.MD5(password);

		CacheItem cacheItem = passCache.get(username);
		if (cacheItem != null && System.currentTimeMillis() - cacheItem.getTime() < cacheLiveTime) {
			if (passwordsMatch(cacheItem.getValue(), hashed)) {
				return username;
			}
			throw new PasswordCheckFailedException(username + "'s password is wrong");
		}

		String dPass = getDbPassword(username);
		if (dPass == null) {
			// Drop any stale entry for a user that no longer exists.
			passCache.remove(username);
			throw new PasswordCheckFailedException(username + " not found.");
		}
		passCache.put(username, new CacheItem(System.currentTimeMillis(), dPass));
		if (passwordsMatch(dPass, hashed)) {
			if (log.isLoggable(Level.FINE)) {
				log.fine("password check for user: " + username + " is passed");
			}
			return username;
		}
		throw new PasswordCheckFailedException(username + "'s password is wrong");
	}

	/**
	 * Reads the stored password for {@code username} from the database.
	 *
	 * @param username the user code to look up
	 * @return the stored password; {@code ""} if the user exists but the
	 *         column is NULL; {@code null} if the user is not found
	 * @throws RemoteException if the data source cannot be resolved or the
	 *                         query fails
	 */
	protected String getDbPassword(String username) throws RemoteException {
		DataSource dataSource = getDataSource();
		if (log.isLoggable(Level.INFO)) {
			log.info("sql: " + sql + ", bind => " + username);
		}
		try (Connection con = dataSource.getConnection();
				PreparedStatement ps = con.prepareStatement(sql)) {
			ps.setString(1, username);
			try (ResultSet rs = ps.executeQuery()) {
				if (!rs.next()) {
					return null; // user not found
				}
				String dPass = rs.getString(1);
				// Return an empty string for a NULL column so it is not mistaken for "user not found".
				return dPass != null ? dPass : "";
			}
		} catch (SQLException e) {
			log.log(Level.SEVERE, "Error to get data from db", e);
			throw new RemoteException("Error to get data from db", e);
		}
	}

	/** Resolves (once) and returns the JNDI data source named {@link #dataSourceName}. */
	private DataSource getDataSource() throws RemoteException {
		DataSource dataSource = ds;
		if (dataSource != null) {
			return dataSource;
		}
		Context ctx = null;
		try {
			ctx = new InitialContext();
			dataSource = (DataSource) ctx.lookup(dataSourceName);
		} catch (NamingException e) {
			log.log(Level.SEVERE, "Error to get data source " + dataSourceName, e);
			throw new RemoteException("Error to get data source " + dataSourceName, e);
		} finally {
			if (ctx != null) {
				try {
					ctx.close();
				} catch (NamingException ignored) {
					// Best-effort release of the naming context; the lookup result is still usable.
					log.warning("Error to close naming context: " + ignored);
				}
			}
		}
		// A racing thread may also resolve it; both results come from the same JNDI name.
		ds = dataSource;
		return dataSource;
	}

	/**
	 * Compares two password values in time independent of where they first
	 * differ, to avoid leaking the stored value's prefix via timing. Returns
	 * {@code false} if either value is {@code null}.
	 */
	private static boolean passwordsMatch(String expected, String actual) {
		if (expected == null || actual == null) {
			return false;
		}
		int diff = expected.length() ^ actual.length();
		int len = Math.min(expected.length(), actual.length());
		for (int i = 0; i < len; i++) {
			diff |= expected.charAt(i) ^ actual.charAt(i);
		}
		return diff == 0;
	}

}
